import torch
import torch.nn as nn
import torch.nn.functional as F

import math


class ResizeConvFeatureUpsampler(nn.Module):
    """Upsample per-scale feature maps with resize-convolution heads.

    One ``nn.Sequential`` head is built per scale. A head starts with a 1x1
    convolution that projects the (concatenated) input features and is
    followed by repeated "nearest-neighbour 2x upsample -> 3x3 convolution"
    stages, with a GELU after every stage except the last one.
    For the motivation behind resize-convolution see
    https://distill.pub/2016/deconv-checkerboard/

    Args:
        num_scales: Number of scales; one head is built per scale and
            ``out_channels`` is split between them by integer division.
        lowest_feature_resolution: Upsampling factor of scale 0. Scale ``i``
            uses ``lowest_feature_resolution // 2 ** i``; the number of 2x
            stages is the floor of log2 of that factor.
        out_channels: Total number of output channels over all scales.
        vit_type: One of ``'vits'``, ``'vitb'``, ``'vitl'``; selects the mono
            feature width (384 / 768 / 1024) used to size the heads.
        no_mono_feature: If True, the heads are sized without mono features.
        gaussian_downsample: ``None``, ``2`` or ``4``. ``2`` drops the last
            three modules of every head and ``4`` drops the last six; any
            other value raises ``NotImplementedError``.
        monodepth_backbone: If True, every head expects a 384-channel input
            and ``forward`` only consumes ``features_list_cnn``.

    Raises:
        KeyError: If ``vit_type`` is not a known key.
        ValueError: If the upsampling factor of some scale is smaller than 1.
        NotImplementedError: If ``gaussian_downsample`` is not None, 2 or 4.
    """

    def __init__(self, num_scales=1,
                 lowest_feature_resolution=8,
                 out_channels=128,
                 vit_type='vits',
                 no_mono_feature=False,
                 gaussian_downsample=None,
                 monodepth_backbone=False,
                 ):
        super(ResizeConvFeatureUpsampler, self).__init__()

        self.num_scales = num_scales
        self.monodepth_backbone = monodepth_backbone

        self.upsampler = nn.ModuleList()

        vit_feature_channel_dict = {
            'vits': 384,
            'vitb': 768,
            'vitl': 1024
        }

        vit_feature_channel = vit_feature_channel_dict[vit_type]

        if monodepth_backbone:
            vit_feature_channel = 384

        # Every scale contributes an equal share of the output channels.
        out_channels = out_channels // num_scales

        for scale_idx in range(num_scales):
            cnn_feature_channels = 128 - (32 * scale_idx)
            mv_transformer_feature_channels = 128 // (2 ** scale_idx)
            if no_mono_feature:
                mono_feature_channels = 0
            else:
                mono_feature_channels = vit_feature_channel // (2 ** scale_idx)

            in_channels = cnn_feature_channels + \
                mv_transformer_feature_channels + mono_feature_channels

            # With a monodepth backbone the head consumes a single
            # 384-channel feature map instead of the concatenation.
            if monodepth_backbone:
                in_channels = 384

            curr_upsample_factor = lowest_feature_resolution // (2 ** scale_idx)
            if curr_upsample_factor < 1:
                raise ValueError(
                    'upsample factor must be >= 1, got %s for scale %d '
                    '(lowest_feature_resolution=%s)'
                    % (curr_upsample_factor, scale_idx,
                       lowest_feature_resolution))

            # Exact integer floor(log2(factor)); avoids truncating a float
            # quotient that may be off by a rounding error.
            num_upsample = int(curr_upsample_factor).bit_length() - 1

            if num_upsample == 0:
                # Nothing to upsample: only project to the output width.
                modules = [nn.Conv2d(in_channels, out_channels, 1)]
            else:
                # Width after the 1x1 projection; it is halved by every
                # intermediate stage below.
                curr_in_channels = out_channels * 2 * max(num_upsample - 1, 1)
                modules = [nn.Conv2d(in_channels, curr_in_channels, 1)]

                for stage_idx in range(num_upsample):
                    modules.append(nn.Upsample(scale_factor=2, mode='nearest'))

                    if stage_idx == num_upsample - 1:
                        # Final stage maps to the output width, no activation.
                        modules.append(nn.Conv2d(curr_in_channels,
                                                 out_channels, 3, 1, 1,
                                                 padding_mode='replicate'))
                    else:
                        modules.append(nn.Conv2d(curr_in_channels,
                                                 curr_in_channels // 2, 3, 1, 1,
                                                 padding_mode='replicate'))
                        curr_in_channels = curr_in_channels // 2
                        modules.append(nn.GELU())

            if gaussian_downsample is not None:
                # Trim trailing modules so the head stops at a lower
                # resolution.
                if gaussian_downsample == 2:
                    del modules[-3:]
                elif gaussian_downsample == 4:
                    del modules[-6:]
                else:
                    raise NotImplementedError

            self.upsampler.append(nn.Sequential(*modules))

    def forward(self, features_list_cnn, features_list_mv, features_list_mono=None):
        """Run every scale through its head and concatenate the results.

        Args:
            features_list_cnn: Per-scale tensors, indexed by scale.
            features_list_mv: Per-scale tensors, indexed by scale; not read
                when ``monodepth_backbone`` is set.
            features_list_mono: Optional per-scale tensors; when given (and
                ``monodepth_backbone`` is not set) they are concatenated
                after the cnn and mv features.

        Returns:
            The outputs of all heads concatenated along ``dim=1``.
        """
        out = []

        for scale_idx in range(self.num_scales):
            if self.monodepth_backbone:
                head_input = features_list_cnn[scale_idx]
            else:
                parts = [features_list_cnn[scale_idx],
                         features_list_mv[scale_idx]]
                if features_list_mono is not None:
                    parts.append(features_list_mono[scale_idx])
                head_input = torch.cat(parts, dim=1)

            out.append(self.upsampler[scale_idx](head_input))

        return torch.cat(out, dim=1)


def _test():
    """Smoke-test a two-scale ResizeConvFeatureUpsampler and print its output shape.

    Runs on ``cuda:0`` when CUDA is available and falls back to the CPU
    otherwise, so the check also works on machines without a GPU.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = ResizeConvFeatureUpsampler(num_scales=2,
                                       lowest_feature_resolution=4,
                                       ).to(device)
    print(model)

    b, h, w = 2, 32, 64

    # scale 1: channel counts match what the first head is built for
    features_list_cnn = [torch.randn(b, 128, h, w, device=device)]
    features_list_mv = [torch.randn(b, 128, h, w, device=device)]
    features_list_mono = [torch.randn(b, 384, h, w, device=device)]

    # scale 2: twice the spatial size, fewer channels
    features_list_cnn.append(torch.randn(b, 96, h * 2, w * 2, device=device))
    features_list_mv.append(torch.randn(b, 64, h * 2, w * 2, device=device))
    features_list_mono.append(torch.randn(b, 192, h * 2, w * 2, device=device))

    out = model(features_list_cnn,
                features_list_mv, features_list_mono)

    print(out.shape)


# Run the smoke test only when this file is executed directly as a script.
if __name__ == '__main__':
    _test()
